{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Automatically created module for IPython interactive environment\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 600x600 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Author: Ron Weiss <ronweiss@gmail.com>, Gael Varoquaux\n",
    "# Modified by Thierry Guillemot <thierry.guillemot.work@gmail.com>\n",
    "# License: BSD 3 clause\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from sklearn import datasets\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "# NOTE: run inside a notebook kernel, __doc__ is only the kernel's\n",
    "# auto-generated module docstring, so this echoes that placeholder text.\n",
    "print(__doc__)\n",
    "\n",
    "# One colour per iris class, indexed by class label.\n",
    "colors = [\n",
    "    'navy',\n",
    "    'turquoise',\n",
    "    'darkorange',\n",
    "]\n",
    "\n",
    "\n",
    "def make_ellipses(gmm, ax):\n",
    "    \"\"\"Draw one translucent ellipse per mixture component on ``ax``.\n",
    "\n",
    "    Each ellipse is centred on the first two coordinates of the component\n",
    "    mean. Its axes have length ``2 * sqrt(2) * sqrt(eigenvalue)`` of the\n",
    "    covariance selected for that component, and it is rotated along the\n",
    "    eigenvector belonging to the first (smallest) eigenvalue.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    gmm : object\n",
    "        Fitted mixture exposing ``covariance_type``, ``covariances_`` and\n",
    "        ``means_``.\n",
    "    ax : matplotlib Axes\n",
    "        Axes that receives the ellipse patches.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``gmm.covariance_type`` is not 'full', 'tied', 'diag' or\n",
    "        'spherical'.\n",
    "    \"\"\"\n",
    "    for n, color in enumerate(colors):\n",
    "        if gmm.covariance_type == 'full':\n",
    "            covariances = gmm.covariances_[n][:2, :2]\n",
    "        elif gmm.covariance_type == 'tied':\n",
    "            covariances = gmm.covariances_[:2, :2]\n",
    "        elif gmm.covariance_type == 'diag':\n",
    "            covariances = np.diag(gmm.covariances_[n][:2])\n",
    "        elif gmm.covariance_type == 'spherical':\n",
    "            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]\n",
    "        else:\n",
    "            # Fail loudly instead of hitting a NameError further down.\n",
    "            raise ValueError('Unsupported covariance_type: %r'\n",
    "                             % gmm.covariance_type)\n",
    "        v, w = np.linalg.eigh(covariances)\n",
    "        # eigh returns eigenvectors as columns: w[:, 0] pairs with v[0],\n",
    "        # which is the value used as the ellipse width below.\n",
    "        u = w[:, 0] / np.linalg.norm(w[:, 0])\n",
    "        angle = np.arctan2(u[1], u[0])\n",
    "        angle = 180 * angle / np.pi  # convert to degrees\n",
    "        v = 2. * np.sqrt(2.) * np.sqrt(v)\n",
    "        # angle is passed by keyword: recent matplotlib releases reject it\n",
    "        # as a fourth positional argument.\n",
    "        ell = mpl.patches.Ellipse(gmm.means_[n, :2], v[0], v[1],\n",
    "                                  angle=180 + angle, color=color)\n",
    "        ell.set_clip_box(ax.bbox)\n",
    "        ell.set_alpha(0.5)\n",
    "        ax.add_artist(ell)\n",
    "    # The aspect setting does not depend on the component, so apply it once.\n",
    "    ax.set_aspect('equal', 'datalim')\n",
    "\n",
    "iris = datasets.load_iris()\n",
    "\n",
    "# A 4-way stratified split gives non-overlapping 75% training / 25%\n",
    "# testing sets; only the first fold is used.\n",
    "skf = StratifiedKFold(n_splits=4)\n",
    "fold_iter = iter(skf.split(iris.data, iris.target))\n",
    "train_index, test_index = next(fold_iter)\n",
    "\n",
    "X_train, y_train = iris.data[train_index], iris.target[train_index]\n",
    "X_test, y_test = iris.data[test_index], iris.target[test_index]\n",
    "\n",
    "# Number of distinct class labels present in the training fold.\n",
    "n_classes = np.unique(y_train).size\n",
    "\n",
    "# Try GMMs using different types of covariances.\n",
    "# One mixture model per covariance parameterisation, keyed by its name.\n",
    "covariance_types = ('spherical', 'diag', 'tied', 'full')\n",
    "estimators = {\n",
    "    cov_type: GaussianMixture(n_components=n_classes,\n",
    "                              covariance_type=cov_type,\n",
    "                              max_iter=20,\n",
    "                              random_state=0)\n",
    "    for cov_type in covariance_types\n",
    "}\n",
    "\n",
    "n_estimators = len(estimators)\n",
    "\n",
    "# Figure width scales with the number of estimators; margins are tight.\n",
    "fig_width = 3 * n_estimators // 2\n",
    "plt.figure(figsize=(fig_width, 6))\n",
    "plt.subplots_adjust(left=.01, right=.99, bottom=.01, top=0.95,\n",
    "                    wspace=.05, hspace=.15)\n",
    "\n",
    "\n",
    "for index, (name, estimator) in enumerate(estimators.items()):\n",
    "    # Class labels are known for the training data, so each component\n",
    "    # mean starts from the mean of the matching class.\n",
    "    class_means = [X_train[y_train == i].mean(axis=0)\n",
    "                   for i in range(n_classes)]\n",
    "    estimator.means_init = np.array(class_means)\n",
    "\n",
    "    # EM estimates the remaining parameters.\n",
    "    estimator.fit(X_train)\n",
    "\n",
    "    ax = plt.subplot(2, n_estimators // 2, index + 1)\n",
    "    make_ellipses(estimator, ax)\n",
    "\n",
    "    # Every sample as a small dot, labelled for the legend.\n",
    "    for n, color in enumerate(colors):\n",
    "        class_points = iris.data[iris.target == n]\n",
    "        ax.scatter(class_points[:, 0], class_points[:, 1], s=0.8,\n",
    "                   color=color, label=iris.target_names[n])\n",
    "    # Test samples again, drawn as crosses.\n",
    "    for n, color in enumerate(colors):\n",
    "        test_points = X_test[y_test == n]\n",
    "        ax.scatter(test_points[:, 0], test_points[:, 1], marker='x',\n",
    "                   color=color)\n",
    "\n",
    "    # Accuracy in percent: share of predictions equal to the class label.\n",
    "    y_train_pred = estimator.predict(X_train)\n",
    "    train_hits = y_train_pred.ravel() == y_train.ravel()\n",
    "    train_accuracy = np.mean(train_hits) * 100\n",
    "    ax.text(0.05, 0.9, 'Train accuracy: %.1f' % train_accuracy,\n",
    "            transform=ax.transAxes)\n",
    "\n",
    "    y_test_pred = estimator.predict(X_test)\n",
    "    test_hits = y_test_pred.ravel() == y_test.ravel()\n",
    "    test_accuracy = np.mean(test_hits) * 100\n",
    "    ax.text(0.05, 0.8, 'Test accuracy: %.1f' % test_accuracy,\n",
    "            transform=ax.transAxes)\n",
    "\n",
    "    ax.set_xticks(())\n",
    "    ax.set_yticks(())\n",
    "    ax.set_title(name)\n",
    "\n",
    "# The legend attaches to the current axes, i.e. the last subplot drawn.\n",
    "plt.legend(scatterpoints=1, loc='lower right', prop=dict(size=12))\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
